In [1]:
from poke_env.player import Gen8EnvSinglePlayer
from poke_env.data import GenData
from gymnasium.spaces import Space, Box, Dict
import numpy as np

class SimpleRLPlayer(Gen8EnvSinglePlayer):
    """Gen 8 single-battle environment with a shaped reward and a flat observation.

    ``embed_battle`` returns a float32 vector of length 1148 and
    ``describe_embedding`` declares the matching ``Box``. Both read the layout
    constants below, so the two cannot drift apart.
    """

    # --- Observation layout -------------------------------------------------
    N_MOVES = 4                 # move slots (active moves, and known moves per Pokemon)
    N_MOVE_CATEGORIES = 3       # Physical / Special / Status one-hot
    N_TYPES = 18                # one-hot width of each type slot
    N_STATS = 6                 # hp, atk, def, spa, spd, spe
    N_BOOSTS = 7                # atk, def, spa, spd, spe, accuracy, evasion
    N_STATUSES = 7              # non-volatile status one-hot
    N_POKEMON_FEATURES = 82     # 2*18 types + 6 base + 6 stats + 7 boosts + 4*5 moves + 7 status
    N_RESERVE = 5               # non-active Pokemon encoded per side
    N_WEATHER = 9
    N_FIELD = 13
    N_SIDE = 24

    # --- Observation bounds -------------------------------------------------
    MAX_SCALED_POWER = 3.0      # base_power / 100 (Gen 8 base powers top out at 250)
    MAX_SCALED_STAT = 3.0       # stat / 400 exceeds 1 for very bulky Pokemon
    MAX_MOVE_MULTIPLIER = 6.0   # 4x effectiveness times the 1.5x STAB factor
    MAX_TURN_FEATURE = 10.0     # turn / 100; Showdown caps battles at 1000 turns

    _STAT_KEYS = ("hp", "atk", "def", "spa", "spd", "spe")
    _BOOST_KEYS = ("atk", "def", "spa", "spd", "spe", "accuracy", "evasion")

    def __init__(self, *args, **kwargs):
        # Set before super().__init__: with start_challenging=True the parent
        # constructor may already trigger battle handling that calls embed_battle.
        self._type_chart = GenData.from_gen(8).type_chart
        super().__init__(*args, **kwargs)

    # ------------------------------------------------------------------ reward

    def calc_reward(self, last_battle, current_battle) -> float:
        """Return the shaped reward for ``current_battle``.

        A win returns 500 immediately. Otherwise the reward sums:
          * (opponent faints - our faints) * 5 + (our HP - opponent HP) * 2,
            using HP fractions of non-fainted Pokemon;
          * per available move: +2 if super effective, -1 if resisted, +0.5 if STAB;
          * when our active Pokemon changed: +3/+1.5/-1/-2 depending on how much
            the defensive type matchup improved or worsened;
          * +1 per teammate above 50% HP and -1.5 per teammate below 25% HP;
          * -500 if the battle is lost.

        NOTE(review): apart from win/loss, every term is an absolute state value
        that is paid again on each step (not a change since ``last_battle``).
        Positive per-turn bonuses can therefore reward prolonging battles. The
        move term also ignores which action was actually taken.
        """
        if current_battle.won:
            return 500.0

        team = list(current_battle.team.values())
        opponents = list(current_battle.opponent_team.values())

        fainted_diff = sum(mon.fainted for mon in opponents) - sum(mon.fainted for mon in team)
        hp_diff = (
            sum(mon.current_hp_fraction for mon in team if not mon.fainted)
            - sum(mon.current_hp_fraction for mon in opponents if not mon.fainted)
        )
        reward = fainted_diff * 5.0 + hp_diff * 2.0

        reward += self._move_effectiveness_bonus(current_battle)
        reward += self._switch_bonus(last_battle, current_battle)
        reward += self._team_health_bonus(current_battle)

        if current_battle.lost:
            reward -= 500.0
        return reward

    def _move_effectiveness_bonus(self, battle):
        """Score the available moves by type effectiveness against the opposing active Pokemon."""
        active = battle.active_pokemon
        opponent = battle.opponent_active_pokemon
        if not active or not opponent:
            return 0.0

        bonus = 0.0
        stab_types = (active.type_1, active.type_2)
        for move in battle.available_moves:
            if not move.type:
                continue
            multiplier = move.type.damage_multiplier(
                opponent.type_1, opponent.type_2, type_chart=self._type_chart
            )
            if multiplier > 1:      # super effective
                bonus += 2.0
            elif multiplier < 1:    # not very effective / immune
                bonus -= 1.0
            if move.type in stab_types:
                bonus += 0.5
        return bonus

    def _switch_bonus(self, last_battle, current_battle):
        """Reward a change of our active Pokemon by how much it improved the defensive matchup."""
        if not last_battle:
            return 0.0
        # Compare species, not objects: when ``last_battle`` is a copy of the
        # battle its Pokemon are distinct objects even if nobody switched.
        if self._species(last_battle.active_pokemon) == self._species(current_battle.active_pokemon):
            return 0.0
        if not (last_battle.opponent_active_pokemon and current_battle.opponent_active_pokemon):
            return 0.0

        old_matchup = self._matchup_score(
            last_battle.active_pokemon, last_battle.opponent_active_pokemon
        )
        new_matchup = self._matchup_score(
            current_battle.active_pokemon, current_battle.opponent_active_pokemon
        )
        # Scores measure how hard the opponent's types hit us, so a drop is an improvement.
        improvement = old_matchup - new_matchup
        if improvement > 0.5:
            return 3.0
        if improvement > 0:
            return 1.5
        if improvement < -0.5:
            return -2.0
        if improvement < 0:
            return -1.0
        return 0.0

    def _matchup_score(self, ours, theirs):
        """Average multiplier of ``theirs``'s own types when hitting ``ours`` (0 if undefined)."""
        if not ours or not theirs or not ours.type_1:
            return 0
        multipliers = [
            attacking_type.damage_multiplier(ours.type_1, ours.type_2, type_chart=self._type_chart)
            for attacking_type in (theirs.type_1, theirs.type_2)
            if attacking_type
        ]
        return sum(multipliers) / len(multipliers) if multipliers else 0

    @staticmethod
    def _species(pokemon):
        """Species identifier of ``pokemon``, or None when there is no Pokemon."""
        return pokemon.species if pokemon else None

    @staticmethod
    def _team_health_bonus(battle):
        """+1 per living teammate above 50% HP, -1.5 per living teammate below 25% HP."""
        alive_hp = [mon.current_hp_fraction for mon in battle.team.values() if not mon.fainted]
        healthy = sum(1 for hp in alive_hp if hp > 0.5)
        critical = sum(1 for hp in alive_hp if hp < 0.25)
        return healthy * 1.0 - critical * 1.5

    # ------------------------------------------------------------- observation

    def embed_battle(self, battle):
        """Encode ``battle`` as a flat float32 vector of length 1148.

        Layout, in order:
          * active-move features (24): base power / 100 (-1 marks an empty
            slot; status moves have 0), effectiveness x STAB against the
            opposing active Pokemon, accuracy, and a 3-way category one-hot;
          * our active Pokemon (82) and the opponent's active Pokemon (82);
          * our reserves and the opponent's reserves (5 x 82 each, zero-padded);
          * weather (9 flags + 9 turn/100), fields (13 + 13), and our and the
            opponent's side conditions (24 flags + 24 value/100 each).
        """
        n_moves, n_cat = self.N_MOVES, self.N_MOVE_CATEGORIES
        moves_base_power = -np.ones(n_moves)     # -1 marks an empty move slot
        moves_dmg_multiplier = np.ones(n_moves)  # neutral unless computed below
        moves_accuracy = np.zeros(n_moves)
        moves_category = np.zeros(n_moves * n_cat)

        active = battle.active_pokemon
        opponent = battle.opponent_active_pokemon

        # Slice so an unexpected extra move cannot overflow the fixed-size arrays.
        for i, move in enumerate(battle.available_moves[:n_moves]):
            moves_base_power[i] = move.base_power / 100
            moves_accuracy[i] = move.accuracy if move.accuracy is not None else 1.0
            self._set_in_range(moves_category[i * n_cat:(i + 1) * n_cat], move.category.value - 1)

            if move.type and opponent:
                try:
                    multiplier = move.type.damage_multiplier(
                        opponent.type_1, opponent.type_2, type_chart=self._type_chart
                    )
                    if active and move.type in (active.type_1, active.type_2):
                        multiplier *= 1.5  # STAB
                    moves_dmg_multiplier[i] = multiplier
                except Exception as e:
                    print(f"Error in damage multiplier: {e}")

        components = {
            "Moves base power": moves_base_power,
            "Moves damage multiplier": moves_dmg_multiplier,
            "Moves accuracy": moves_accuracy,
            "Moves category": moves_category,
            "Active pokemon vector": self._encode_pokemon(active),
            "Opponent pokemon vector": self._encode_pokemon(opponent),
            "Team vector": self._encode_reserve(battle.team, active),
            "Opponent vector": self._encode_reserve(battle.opponent_team, opponent),
            "Weather vector": self._encode_conditions(battle.weather, self.N_WEATHER),
            "Field vector": self._encode_conditions(battle.fields, self.N_FIELD),
            "Our side vector": self._encode_conditions(battle.side_conditions, self.N_SIDE),
            "Opponent side vector": self._encode_conditions(
                battle.opponent_side_conditions, self.N_SIDE
            ),
        }
        final_vector = np.concatenate(list(components.values()))

        if not np.isfinite(final_vector).all():
            print("NaN or Inf detected in final observation vector")
            for label, part in components.items():
                print(f"{label}:", part)
            # Sanitize so one bad feature cannot push NaN into the policy network.
            final_vector = np.nan_to_num(final_vector)

        return final_vector.astype(np.float32)

    def _encode_pokemon(self, pokemon):
        """Encode one Pokemon as N_POKEMON_FEATURES values (all zeros when absent)."""
        if not pokemon or pokemon.stats is None:
            return np.zeros(self.N_POKEMON_FEATURES)

        type_1_vector = np.zeros(self.N_TYPES)
        type_2_vector = np.zeros(self.N_TYPES)
        if pokemon.type_1:
            self._set_in_range(type_1_vector, pokemon.type_1.value - 1)
        if pokemon.type_2:  # type_2 can be None
            self._set_in_range(type_2_vector, pokemon.type_2.value - 1)

        base_stats = np.array(
            [pokemon.base_stats.get(key, 0) for key in self._STAT_KEYS], dtype=np.float32
        ) / 255.0  # max base stat is 255
        # Stat values may be None (presumably unrevealed opponent stats); numpy
        # turns None into NaN, which nan_to_num maps to 0.
        current_stats = np.nan_to_num(
            np.array([pokemon.stats.get(key, 0) for key in self._STAT_KEYS], dtype=np.float32) / 400.0
        )
        boosts = np.nan_to_num(
            np.array([pokemon.boosts.get(key, 0) for key in self._BOOST_KEYS], dtype=np.float32) / 6.0
        )  # boosts range -6..+6, so this lies in [-1, 1]

        move_width = 2 + self.N_MOVE_CATEGORIES  # base power, accuracy, category one-hot
        moves_vector = np.zeros(self.N_MOVES * move_width)
        for i, move in enumerate(list(pokemon.moves.values())[: self.N_MOVES]):
            start = i * move_width
            moves_vector[start] = move.base_power / 100.0
            moves_vector[start + 1] = move.accuracy if move.accuracy is not None else 1.0
            self._set_in_range(moves_vector[start + 2:start + move_width], move.category.value - 1)
        moves_vector = np.nan_to_num(moves_vector)

        status_vector = np.zeros(self.N_STATUSES)
        if pokemon.status:  # status can be None
            self._set_in_range(status_vector, pokemon.status.value - 1)

        return np.concatenate([
            type_1_vector,   # 18
            type_2_vector,   # 18
            base_stats,      # 6
            current_stats,   # 6
            boosts,          # 7
            moves_vector,    # 20 (4 moves x 5 values)
            status_vector,   # 7
        ])

    def _encode_reserve(self, team, active):
        """Encode up to N_RESERVE non-active members of ``team``, zero-padded to a fixed length."""
        vectors = [self._encode_pokemon(mon) for mon in team.values() if mon != active]
        # Truncate as well as pad: if ``active`` is None every member lands here,
        # and an over-long list would change the observation length.
        vectors = vectors[: self.N_RESERVE]
        vectors.extend(np.zeros(self.N_POKEMON_FEATURES) for _ in range(self.N_RESERVE - len(vectors)))
        return np.concatenate(vectors)

    @classmethod
    def _encode_conditions(cls, conditions, size):
        """Encode an {enum: int} mapping as ``size`` presence flags followed by ``size`` values / 100."""
        flags = np.zeros(size)
        values = np.zeros(size)
        for condition, value in (conditions or {}).items():
            index = condition.value - 1
            cls._set_in_range(flags, index)
            cls._set_in_range(values, index, value / 100.0)
        return np.concatenate([flags, values])

    @staticmethod
    def _set_in_range(vector, index, value=1.0):
        """Write ``value`` at ``index`` only if ``index`` is a valid non-negative position.

        Enum members beyond the encoded width are skipped. Without this check
        they would raise IndexError, or for index -1 silently wrap to the last slot.
        """
        if 0 <= index < len(vector):
            vector[index] = value

    def describe_embedding(self) -> Space:
        """Return the Box observation space matching ``embed_battle``'s layout and value ranges."""
        n_moves, n_cat = self.N_MOVES, self.N_MOVE_CATEGORIES

        # Per-Pokemon bounds, in _encode_pokemon order.
        pokemon_low = np.concatenate([
            np.zeros(2 * self.N_TYPES),           # type one-hots
            np.zeros(self.N_STATS),               # base stats / 255
            np.zeros(self.N_STATS),               # stats / 400
            -np.ones(self.N_BOOSTS),              # boosts / 6 lie in [-1, 1]
            np.zeros(n_moves * (2 + n_cat)),      # known moves
            np.zeros(self.N_STATUSES),            # status one-hot
        ])
        pokemon_high = np.concatenate([
            np.ones(2 * self.N_TYPES),
            np.ones(self.N_STATS),
            np.full(self.N_STATS, self.MAX_SCALED_STAT),
            np.ones(self.N_BOOSTS),
            np.tile([self.MAX_SCALED_POWER] + [1.0] * (1 + n_cat), n_moves),
            np.ones(self.N_STATUSES),
        ])

        def condition_bounds(size):
            # Presence flags in [0, 1]; turn or stack values / 100 up to MAX_TURN_FEATURE.
            return (
                np.zeros(2 * size),
                np.concatenate([np.ones(size), np.full(size, self.MAX_TURN_FEATURE)]),
            )

        sections = [
            (-np.ones(n_moves), np.full(n_moves, self.MAX_SCALED_POWER)),     # move base power
            (np.zeros(n_moves), np.full(n_moves, self.MAX_MOVE_MULTIPLIER)),  # effectiveness x STAB
            (np.zeros(n_moves), np.ones(n_moves)),                            # move accuracy
            (np.zeros(n_moves * n_cat), np.ones(n_moves * n_cat)),            # move categories
            (pokemon_low, pokemon_high),                                      # our active
            (pokemon_low, pokemon_high),                                      # opponent active
            (np.tile(pokemon_low, self.N_RESERVE), np.tile(pokemon_high, self.N_RESERVE)),
            (np.tile(pokemon_low, self.N_RESERVE), np.tile(pokemon_high, self.N_RESERVE)),
            condition_bounds(self.N_WEATHER),
            condition_bounds(self.N_FIELD),
            condition_bounds(self.N_SIDE),                                    # our side
            condition_bounds(self.N_SIDE),                                    # opponent side
        ]
        low = np.concatenate([lo for lo, _ in sections]).astype(np.float32)
        high = np.concatenate([hi for _, hi in sections]).astype(np.float32)
        return Box(low, high, dtype=np.float32)
        
from poke_env.player import Player
class BestBasePowerPlayer(Player):
    """Greedy baseline that always uses the available move with the highest base power."""

    def choose_move(self, battle):
        """Return an order for the strongest available move, or a random order if none exist."""
        candidates = battle.available_moves
        if not candidates:
            # e.g. a forced switch: no moves to choose from, so defer to a random legal order
            return self.choose_random_move(battle)
        strongest = max(candidates, key=lambda candidate: candidate.base_power)
        return self.create_order(strongest)
        

from poke_env.player import Player
from poke_env.data import GenData

class SmartDamagePlayer(Player):
    """Greedy baseline: picks the move maximizing base power x type effectiveness x STAB."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.type_chart = GenData.from_gen(8).type_chart

    def choose_move(self, battle):
        """Return an order for the highest-scoring available move, or a random order if none exist."""
        if not battle.available_moves:
            return self.choose_random_move(battle)
        best_move = max(
            battle.available_moves,
            key=lambda move: move.base_power * self._calculate_multiplier(move, battle),
        )
        return self.create_order(best_move)

    def _calculate_multiplier(self, move, battle):
        """Return the damage multiplier for ``move``: type effectiveness times 1.5 if STAB.

        Returns 0 for typeless moves. When the opposing active Pokemon is unknown,
        effectiveness is treated as neutral (1x) instead of raising AttributeError.
        """
        if not move.type:
            return 0

        opponent = battle.opponent_active_pokemon
        multiplier = 1.0
        if opponent is not None:
            multiplier = move.type.damage_multiplier(
                opponent.type_1,
                opponent.type_2,
                type_chart=self.type_chart,
            )

        # Same-type attack bonus
        active = battle.active_pokemon
        if active and move.type in (active.type_1, active.type_2):
            multiplier *= 1.5

        return multiplier
    
# TODO: Strategic player (switching and Dynamax). Z-Moves and Mega Evolution do not exist in gen8randombattle, so they are out of scope here.

In [2]:
from poke_env.player.random_player import RandomPlayer
# Create the environment against a random opponent
opponent = RandomPlayer(battle_format="gen8randombattle")
env = SimpleRLPlayer(
    battle_format="gen8randombattle",
    opponent=opponent,
    start_challenging=True,
)

try:
    # Smoke test: play battles with uniformly random actions (no learning happens here)
    n_battles = 1  # Number of battles to run
    for _ in range(n_battles):
        # gymnasium API: reset() returns (observation, info)
        obs, info = env.reset()
        done = False

        while not done:
            # Sample a random action
            action = env.action_space.sample()

            # Take a step; a battle can end by termination (win/loss) or truncation
            obs, reward, terminated, truncated, info = env.step(action)
            done = terminated or truncated
            print(f"Observation: {obs}")
            print(f"Action: {action}, Reward: {reward}")

        print("Battle finished!")

finally:
    # Properly close the environment
    env.close()

Observation: [0.95 0.   0.9  ... 0.   0.   0.  ]
Action: 19, Reward: 19.5
Observation: [0. 1. 1. ... 0. 0. 0.]
Action: 5, Reward: 17.5
Observation: [ 1. -1. -1. ...  0.  0.  0.]
Action: 7, Reward: 15.909940828402366
Observation: [1.2 1.2 0.  ... 0.  0.  0. ]
Action: 1, Reward: 17.408728707190246
Observation: [1.2 1.2 0.  ... 0.  0.  0. ]
Action: 15, Reward: 18.488122646584184
Observation: [0.95 0.95 0.   ... 0.   0.   0.  ]
Action: 8, Reward: 17.807516585978124
Observation: [0.95 0.   0.9  ... 0.   0.   0.  ]
Action: 19, Reward: 11.807516585978124
Observation: [0.8 1.  0.5 ... 0.  0.  0. ]
Action: 18, Reward: 14.307516585978124
Observation: [0.8 1.  0.5 ... 0.  0.  0. ]
Action: 7, Reward: 16.567516585978126
Observation: [0.8 1.  0.5 ... 0.  0.  0. ]
Action: 9, Reward: 11.567516585978124
Observation: [0.8 1.  0.5 ... 0.  0.  0. ]
Action: 0, Reward: 11.685788190916394
Observation: [0.8 1.  0.5 ... 0.  0.  0. ]
Action: 0, Reward: 12.509244981039853
Observation: [0.8 1.  0.5 ... 0.  0.  0.

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.distributions import Categorical

class PPOMemory:
    """On-policy rollout buffer: parallel per-step lists of PPO transition data."""

    def __init__(self):
        self.clear_memory()

    def generate_batches(self, batch_size):
        """Return the stored rollout as arrays plus shuffled minibatch indices.

        Returns a 7-tuple ``(states, actions, probs, vals, rewards, dones, batches)``.
        ``batches`` is a list of int64 index arrays of length ``batch_size``
        (the last may be shorter) that together cover every stored step once.
        """
        total = len(self.states)
        order = np.arange(total, dtype=np.int64)
        np.random.shuffle(order)
        batches = [order[start:start + batch_size] for start in range(0, total, batch_size)]

        buffers = (self.states, self.actions, self.probs, self.vals, self.rewards, self.dones)
        arrays = tuple(np.array(buffer) for buffer in buffers)
        return (*arrays, batches)

    def store_memory(self, state, action, probs, vals, reward, done):
        """Append one transition to every buffer."""
        buffers = (self.states, self.actions, self.probs, self.vals, self.rewards, self.dones)
        for buffer, item in zip(buffers, (state, action, probs, vals, reward, done)):
            buffer.append(item)

    def clear_memory(self):
        """Drop all stored transitions."""
        self.states, self.actions, self.probs = [], [], []
        self.vals, self.rewards, self.dones = [], [], []

class ActorNetwork(nn.Module):
    """Policy MLP mapping a state vector to a probability distribution over actions.

    Four hidden blocks (2048 -> 1024 -> 512 -> 256), each Linear + BatchNorm1d +
    ReLU, with Dropout(0.1) after the first three, then a Linear head and softmax.
    Linear weights use Xavier-uniform init and biases start at 0.01.

    NOTE(review): BatchNorm1d and Dropout behave differently in train and eval
    mode, and BatchNorm1d in train mode rejects a batch of a single sample.
    Log-probs collected in eval mode and re-evaluated in train mode therefore
    give PPO ratios that differ from 1 even before any update.
    """

    def __init__(self, input_dims, n_actions):
        super(ActorNetwork, self).__init__()

        hidden_sizes = (2048, 1024, 512, 256)
        layers = []
        fan_in = input_dims
        for depth, width in enumerate(hidden_sizes):
            layers += [nn.Linear(fan_in, width), nn.BatchNorm1d(width), nn.ReLU()]
            if depth < len(hidden_sizes) - 1:  # no dropout after the last hidden block
                layers.append(nn.Dropout(0.1))
            fan_in = width
        layers.append(nn.Linear(fan_in, n_actions))

        # A single Sequential named ``actor`` keeps state_dict keys
        # (actor.0.weight, actor.1.running_mean, ...) compatible with saved checkpoints.
        self.actor = nn.Sequential(*layers)

        linear_layers = [module for module in self.actor if isinstance(module, nn.Linear)]
        for linear in linear_layers:
            nn.init.xavier_uniform_(linear.weight)
            nn.init.constant_(linear.bias, 0.01)

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, state):
        """Return action probabilities of shape (batch, n_actions); 1-D input becomes a batch of one."""
        batched = state.unsqueeze(0) if state.dim() == 1 else state
        logits = self.actor(batched)
        # Shift by the row max for numerical stability (softmax is shift-invariant).
        shifted = logits - logits.max(dim=-1, keepdim=True).values
        return self.softmax(shifted)

class CriticNetwork(nn.Module):
    """Value MLP mapping a state vector to a scalar state-value estimate.

    Hidden blocks 2048 -> 1024 -> 512 -> 256, each Linear + ReLU + LayerNorm,
    with Dropout(0.1) after the first three, then a Linear(256, 1) head.
    Linear weights use orthogonal init (gain 1) and biases start at 0.
    """

    def __init__(self, input_dims):
        super(CriticNetwork, self).__init__()

        widths = [input_dims, 2048, 1024, 512, 256]
        n_hidden = len(widths) - 1
        modules = []
        for idx, (fan_in, fan_out) in enumerate(zip(widths, widths[1:])):
            modules.extend([nn.Linear(fan_in, fan_out), nn.ReLU(), nn.LayerNorm(fan_out)])
            if idx < n_hidden - 1:  # the last hidden block has no dropout
                modules.append(nn.Dropout(0.1))
        modules.append(nn.Linear(widths[-1], 1))

        # Kept as one Sequential named ``critic`` so state_dict keys match saved checkpoints.
        self.critic = nn.Sequential(*modules)

        for module in self.critic:
            if not isinstance(module, nn.Linear):
                continue
            nn.init.orthogonal_(module.weight, gain=1)
            nn.init.constant_(module.bias, 0)

    def forward(self, state):
        """Return values of shape (batch, 1); 1-D input is treated as a batch of one."""
        batched = state.unsqueeze(0) if state.dim() == 1 else state
        return self.critic(batched)

In [4]:
from torch.distributions import Categorical
class PPOAgent:
    """Clipped-PPO agent with separate actor/critic networks and an on-policy buffer.

    Args:
        input_dims: Length of the observation vector.
        n_actions: Number of discrete actions.
        lr: Adam learning rate for both networks.
        gamma: Discount factor.
        alpha: Unused; kept for backward compatibility.
        gae_lambda: GAE lambda.
        policy_clip: PPO clipping epsilon.
        batch_size: Minibatch size used by ``learn``.
        n_epochs: Passes over the rollout buffer per ``learn`` call.
    """

    def __init__(self, input_dims, n_actions, 
                 lr=0.0003, gamma=0.99, alpha=0.2, 
                 gae_lambda=0.95, policy_clip=0.2, 
                 batch_size=64, n_epochs=10):
        self.gamma = gamma
        self.policy_clip = policy_clip
        self.n_epochs = n_epochs
        self.gae_lambda = gae_lambda
        self.batch_size = batch_size

        self.actor = ActorNetwork(input_dims, n_actions)
        self.critic = CriticNetwork(input_dims)
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=lr)
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=lr)

        # Max gradient norm applied to each network in ``learn``
        self.max_grad_norm = 0.5

        self.memory = PPOMemory()

        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.actor.to(self.device)
        self.critic.to(self.device)

    def remember(self, state, action, probs, vals, reward, done):
        """Store one transition (``probs`` is the log-prob of ``action``)."""
        self.memory.store_memory(state, action, probs, vals, reward, done)

    def choose_action(self, observation):
        """Sample an action for ``observation``.

        Returns ``(action, log_prob, value)`` as Python scalars.
        """
        # Both networks run in eval mode so dropout does not perturb the stored
        # log-probs and value estimates; train mode is restored afterwards.
        self.actor.eval()
        self.critic.eval()
        try:
            with torch.no_grad():
                state = torch.as_tensor(observation, dtype=torch.float32, device=self.device)
                if state.dim() == 1:
                    state = state.unsqueeze(0)

                probs = self.actor(state)
                if torch.isnan(probs).any() or torch.isinf(probs).any():
                    print("NaN or Inf detected in action probabilities")
                    probs = torch.ones_like(probs) / probs.shape[-1]

                dist = Categorical(probs)
                value = self.critic(state)
                action = dist.sample()
                log_prob = dist.log_prob(action)
        finally:
            self.actor.train()
            self.critic.train()
        return action.item(), log_prob.item(), value.item()

    @staticmethod
    def _compute_advantages(rewards, values, dones, gamma, gae_lambda):
        """Generalized Advantage Estimation over one rollout, computed backwards in O(n).

            delta_t = r_t + gamma * V_{t+1} * (1 - done_t) - V_t
            A_t     = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}

        The ``(1 - done_t)`` factor on the recursion keeps advantages from leaking
        across episode boundaries. The final step has no successor value and is
        given advantage 0.

        Returns:
            float32 array with one advantage per step.
        """
        n_steps = len(rewards)
        advantages = np.zeros(n_steps, dtype=np.float32)
        running = 0.0
        for t in reversed(range(n_steps - 1)):
            not_done = 1.0 - float(dones[t])
            delta = rewards[t] + gamma * values[t + 1] * not_done - values[t]
            running = delta + gamma * gae_lambda * not_done * running
            advantages[t] = running
        return advantages

    def learn(self):
        """Run ``n_epochs`` of clipped-PPO updates on the stored rollout, then clear it."""
        state_arr, action_arr, old_log_prob_arr, vals_arr, reward_arr, dones_arr, _ = \
            self.memory.generate_batches(self.batch_size)

        # Advantages depend only on the stored rollout, so compute them once, not per epoch.
        advantage = torch.tensor(
            self._compute_advantages(reward_arr, vals_arr, dones_arr, self.gamma, self.gae_lambda),
            dtype=torch.float32,
            device=self.device,
        )
        values = torch.tensor(vals_arr, dtype=torch.float32, device=self.device)

        for _ in range(self.n_epochs):
            # Fresh shuffle of minibatch indices each epoch
            batches = self.memory.generate_batches(self.batch_size)[-1]
            for batch in batches:
                # BatchNorm1d in the actor cannot run in train mode on a single sample.
                if len(batch) < 2:
                    continue

                index = torch.as_tensor(batch, dtype=torch.long, device=self.device)
                states = torch.tensor(state_arr[batch], dtype=torch.float32, device=self.device)
                old_log_probs = torch.tensor(old_log_prob_arr[batch], dtype=torch.float32, device=self.device)
                actions = torch.tensor(action_arr[batch], device=self.device)

                probs = self.actor(states)
                if torch.isnan(probs).any() or torch.isinf(probs).any():
                    print("NaN or Inf detected in batch action probabilities")
                    probs = torch.ones_like(probs) / probs.shape[-1]

                dist = Categorical(probs)
                critic_value = self.critic(states).squeeze(-1)

                new_log_probs = dist.log_prob(actions)
                # Ratio in log space avoids under/overflow of exp(new) / exp(old).
                prob_ratio = (new_log_probs - old_log_probs).exp()
                batch_advantage = advantage[index]
                weighted_probs = batch_advantage * prob_ratio
                weighted_clipped_probs = torch.clamp(
                    prob_ratio, 1 - self.policy_clip, 1 + self.policy_clip
                ) * batch_advantage

                # Entropy bonus on the actor loss only
                entropy = dist.entropy().mean()
                actor_loss = -torch.min(weighted_probs, weighted_clipped_probs).mean() - 0.01 * entropy

                returns = batch_advantage + values[index]
                critic_loss = ((returns - critic_value) ** 2).mean()

                total_loss = actor_loss + 0.5 * critic_loss

                self.actor_optimizer.zero_grad()
                self.critic_optimizer.zero_grad()
                total_loss.backward()

                torch.nn.utils.clip_grad_norm_(self.actor.parameters(), self.max_grad_norm)
                torch.nn.utils.clip_grad_norm_(self.critic.parameters(), self.max_grad_norm)

                self.actor_optimizer.step()
                self.critic_optimizer.step()

        self.memory.clear_memory()

In [5]:
import os
import traceback

def _save_ppo_checkpoint(agent, battle, avg_reward, win_rate, path):
    """Serialize the agent's networks, optimizers and latest eval metrics to ``path``.

    The keys written here are the ones the checkpoint-loading code reads back
    ('battle', '*_state_dict', 'avg_reward', 'win_rate').
    """
    torch.save(
        {
            'battle': battle,
            'actor_state_dict': agent.actor.state_dict(),
            'critic_state_dict': agent.critic.state_dict(),
            'actor_optimizer_state_dict': agent.actor_optimizer.state_dict(),
            'critic_optimizer_state_dict': agent.critic_optimizer.state_dict(),
            'avg_reward': avg_reward,
            'win_rate': win_rate,
        },
        path,
    )


def train_ppo():
    """Train a PPO agent in gen8 random battles against a SmartDamagePlayer.

    The agent is warm-started from ``load_model_path`` when that file exists
    (network weights only; optimizer state is not restored). Every
    ``EVAL_FREQUENCY`` battles it is evaluated with ``evaluate_ppo_agent``; the
    LR schedulers are stepped on ``avg_reward + win_rate``, and checkpoints are
    written to ``MODEL_DIR`` (latest, best reward, best win rate, best combined).

    Any exception raised during training is printed with its traceback and the
    metrics gathered so far are still returned; the environment is always
    closed.

    Returns:
        tuple: (agent, rewards_history, win_history, eval_rewards_history,
        eval_winrate_history, eval_combined_history)
    """
    # Hyperparameters
    BATCH_SIZE = 512
    GAMMA = 0.99
    MIN_BUFFER_SIZE = 512  # transitions collected before each learn() call
    N_BATTLES = 2500

    INITIAL_LR = 0.0003
    MIN_LR = 0.00001

    MODEL_DIR = "saved_models_ppo_final_bestsmartdamage"
    os.makedirs(MODEL_DIR, exist_ok=True)

    # Evaluation metrics
    best_eval_reward = float('-inf')
    best_win_rate = 0
    best_combined_reward = float('-inf')
    EVAL_FREQUENCY = 25
    N_EVAL_BATTLES = 25

    # Initialize PPO agent
    input_dims = 1148
    n_actions = 22
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    agent = PPOAgent(
        input_dims=input_dims,
        n_actions=n_actions,
        lr=INITIAL_LR,
        batch_size=BATCH_SIZE,
        gamma=GAMMA,
        n_epochs=4,
    )

    load_model_path = "saved_models_ppo_final_bestbasepower/latest_model_ppo.pth"

    start_battle = 0
    if load_model_path and os.path.exists(load_model_path):
        # map_location remaps tensors saved on a GPU onto the available device,
        # so the checkpoint also loads on a CPU-only machine.
        checkpoint = torch.load(load_model_path, map_location=device)
        agent.actor.load_state_dict(checkpoint['actor_state_dict'])
        agent.critic.load_state_dict(checkpoint['critic_state_dict'])
        # Optimizer states are deliberately NOT restored: the optimizers keep
        # the fresh state (and INITIAL_LR) they were created with above.
        start_battle = checkpoint['battle'] + 1
        best_eval_reward_old = checkpoint['avg_reward']
        best_win_rate_old = checkpoint['win_rate']
        best_combined_reward_old = checkpoint['avg_reward'] + checkpoint['win_rate']
        print(f"Loaded model from battle {start_battle}")
        print(f"Previous best reward: {best_eval_reward_old:.2f}")
        print(f"Previous best win rate: {best_win_rate_old:.3f}")

    # Learning rate schedulers: both track the combined eval metric (higher is better).
    actor_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        agent.actor_optimizer,
        mode='max',
        factor=0.75,         # Less aggressive reduction (was 0.5)
        patience=5,          # More patience before reducing (was 3)
        threshold=0.02,      # Higher threshold (was 0.01)
        min_lr=MIN_LR,
        verbose=True
    )

    critic_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        agent.critic_optimizer,
        mode='max',
        factor=0.75,         # Less aggressive reduction (was 0.5)
        patience=5,          # More patience before reducing (was 3)
        threshold=0.02,      # Higher threshold (was 0.01)
        min_lr=MIN_LR,
        verbose=True
    )

    # Create opponents (random_opponent is currently unused; kept for easy switching)
    random_opponent = RandomPlayer(battle_format="gen8randombattle")
    smart_damage_opponent = SmartDamagePlayer(battle_format="gen8randombattle")

    env = SimpleRLPlayer(battle_format="gen8randombattle", opponent=smart_damage_opponent, start_challenging=True)

    # Training metrics
    rewards_history = []
    win_history = []
    eval_rewards_history = []
    eval_winrate_history = []
    eval_combined_history = []

    try:
        for battle in range(N_BATTLES):
            battle_done = False
            state_tuple = env.reset()
            # reset() may return (observation, info); keep only the observation.
            state = state_tuple[0] if isinstance(state_tuple, tuple) else state_tuple
            state = np.array(state, dtype=np.float32)
            total_reward = 0

            while not battle_done:
                # Get action from PPO agent
                action, prob, val = agent.choose_action(state)

                # Take action and observe result
                next_state, reward, done, truncated, info = env.step(action)
                battle_done = done or truncated
                next_state = np.array(next_state, dtype=np.float32)
                total_reward += reward

                # Store the transition as terminal when the battle ends for ANY
                # reason (finished or truncated). Storing only `done` let a
                # truncated battle run straight into the next battle's entries
                # in the buffer (presumably learn() uses this flag to cut
                # returns/advantages at episode boundaries).
                agent.remember(state, action, prob, val, reward, battle_done)
                state = next_state

                # learn() clears the memory afterwards, so this triggers about
                # once every MIN_BUFFER_SIZE transitions.
                if len(agent.memory.states) >= MIN_BUFFER_SIZE:
                    agent.learn()

            current_actor_lr = agent.actor_optimizer.param_groups[0]['lr']
            current_critic_lr = agent.critic_optimizer.param_groups[0]['lr']

            print(f"Battle {battle + 1}/{N_BATTLES}, "
                  f"Total Reward: {total_reward}, "
                  f"Actor LR: {current_actor_lr:.9f}, "
                  f"Critic LR: {current_critic_lr:.9f}")

            rewards_history.append(total_reward)
            # Heuristic win flag based on the summed reward (calc_reward returns
            # 500.0 when the battle is won). NOTE(review): a proxy, not the
            # actual battle result.
            win_history.append(1 if total_reward > 250 else 0)

            # Periodic evaluation
            if (battle + 1) % EVAL_FREQUENCY == 0:
                avg_eval_reward, win_rate = evaluate_ppo_agent(agent, N_EVAL_BATTLES)
                combined_metric = avg_eval_reward + win_rate
                eval_rewards_history.append(avg_eval_reward)
                eval_winrate_history.append(win_rate)
                eval_combined_history.append(combined_metric)

                # Update learning rates
                actor_scheduler.step(combined_metric)
                critic_scheduler.step(combined_metric)

                print(f"\nEvaluation at battle {battle + 1}:")
                print(f"Average Evaluation Reward: {avg_eval_reward:.2f}")
                print(f"Win Rate: {win_rate:.3f}")
                print(f"Combined Metric: {combined_metric:.2f}")

                _save_ppo_checkpoint(agent, battle, avg_eval_reward, win_rate,
                                     os.path.join(MODEL_DIR, 'latest_model_ppo.pth'))

                # Save best models based on different metrics
                if avg_eval_reward > best_eval_reward:
                    best_eval_reward = avg_eval_reward
                    _save_ppo_checkpoint(agent, battle, avg_eval_reward, win_rate,
                                         os.path.join(MODEL_DIR, 'best_reward_model_ppo.pth'))

                if win_rate > best_win_rate:
                    best_win_rate = win_rate
                    _save_ppo_checkpoint(agent, battle, avg_eval_reward, win_rate,
                                         os.path.join(MODEL_DIR, 'best_winrate_model_ppo.pth'))

                if combined_metric > best_combined_reward:
                    best_combined_reward = combined_metric
                    _save_ppo_checkpoint(agent, battle, avg_eval_reward, win_rate,
                                         os.path.join(MODEL_DIR, 'best_combined_model_ppo.pth'))

                print(f"Best Evaluation Reward so far: {best_eval_reward:.2f}")
                print(f"Best Win Rate so far: {best_win_rate:.3f}")
                print(f"Best Combined Metric so far: {best_combined_reward:.2f}")
                print("--------------------")

    except Exception as e:
        # Best-effort: report the failure and fall through to return the
        # metrics collected so far.
        print(f"An error occurred: {e}")
        print(f"Traceback: {traceback.format_exc()}")

    finally:
        env.close()

    return agent, rewards_history, win_history, eval_rewards_history, eval_winrate_history, eval_combined_history

In [6]:
def evaluate_ppo_agent(agent, n_eval_battles=25, win_threshold=250):
    """Evaluate the PPO agent during training against a fresh SmartDamagePlayer.

    The actor and critic are switched to eval mode for the run and are always
    switched back to train mode afterwards; the evaluation environment is
    always closed.

    Args:
        agent: PPO agent exposing ``actor``, ``critic`` and ``choose_action``.
        n_eval_battles: number of battles to play; must be positive.
        win_threshold: a battle counts as a win when its summed reward exceeds
            this value (default 250, the same heuristic train_ppo uses).

    Returns:
        tuple: (avg_reward, win_rate), where win_rate is a fraction in [0, 1].

    Raises:
        ValueError: if ``n_eval_battles`` is not positive (previously this
            surfaced as a ZeroDivisionError after the run).
    """
    if n_eval_battles <= 0:
        raise ValueError(f"n_eval_battles must be positive, got {n_eval_battles}")

    agent.actor.eval()
    agent.critic.eval()

    opponent = SmartDamagePlayer(battle_format="gen8randombattle")
    env = SimpleRLPlayer(battle_format="gen8randombattle", opponent=opponent, start_challenging=True)

    total_rewards = 0
    wins = 0

    try:
        for _ in range(n_eval_battles):
            state_tuple = env.reset()
            # reset() may return (observation, info); keep only the observation.
            state = state_tuple[0] if isinstance(state_tuple, tuple) else state_tuple
            state = np.array(state, dtype=np.float32)
            battle_done = False
            battle_reward = 0

            while not battle_done:
                with torch.no_grad():
                    action, _, _ = agent.choose_action(state)

                next_state, reward, done, truncated, info = env.step(action)
                battle_done = done or truncated
                battle_reward += reward
                # Keep observations float32, exactly as during training.
                state = np.array(next_state, dtype=np.float32)

            total_rewards += battle_reward
            if battle_reward > win_threshold:
                wins += 1

    finally:
        env.close()
        agent.actor.train()
        agent.critic.train()

    avg_reward = total_rewards / n_eval_battles
    win_rate = wins / n_eval_battles
    return avg_reward, win_rate

In [7]:
# train_ppo() returns a 6-tuple: (agent, rewards_history, win_history,
# eval_rewards_history, eval_winrate_history, eval_combined_history).
# Despite its name, trained_ppo_agent holds that whole tuple; the agent is element 0.
trained_ppo_agent = train_ppo()


Loaded model from battle 1250
Previous best reward: 16.83
Previous best win rate: 0.280
Battle 1/2500, Total Reward: 965.0642579713062, Actor LR: 0.000300000, Critic LR: 0.000300000
Battle 2/2500, Total Reward: -323.0339041999265, Actor LR: 0.000300000, Critic LR: 0.000300000
Battle 3/2500, Total Reward: -404.78622586056576, Actor LR: 0.000300000, Critic LR: 0.000300000
Battle 4/2500, Total Reward: -388.6244732299567, Actor LR: 0.000300000, Critic LR: 0.000300000
Battle 5/2500, Total Reward: -471.35436046640245, Actor LR: 0.000300000, Critic LR: 0.000300000
Battle 6/2500, Total Reward: -522.2812134609333, Actor LR: 0.000300000, Critic LR: 0.000300000
Battle 7/2500, Total Reward: -558.9581653002324, Actor LR: 0.000300000, Critic LR: 0.000300000
Battle 8/2500, Total Reward: -430.2784109892853, Actor LR: 0.000300000, Critic LR: 0.000300000
Battle 9/2500, Total Reward: -467.80484828296017, Actor LR: 0.000300000, Critic LR: 0.000300000
Battle 10/2500, Total Reward: -140.545318251573, Actor 

: 

In [13]:
import torch
def load_ppo_model(model_path, input_dims=1148, n_actions=22, load_optimizers=True):
    """Rebuild a PPOAgent and restore a checkpoint written by train_ppo.

    Args:
        model_path: path to a ``.pth`` checkpoint dict with 'actor_state_dict',
            'critic_state_dict', 'battle', 'avg_reward' and 'win_rate' keys.
        input_dims: observation size the saved networks were built with.
        n_actions: action count the saved networks were built with.
        load_optimizers: also restore optimizer states (only useful if training
            continues). Skipped for any optimizer whose state is absent from the
            checkpoint.

    Returns:
        The PPOAgent with the saved weights loaded.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Create a new agent with the same architecture as the saved one.
    agent = PPOAgent(input_dims=input_dims, n_actions=n_actions)

    # map_location remaps tensors saved on a GPU onto the available device, so a
    # CUDA checkpoint can be loaded on a CPU-only machine (previously `device`
    # was computed but never used and such loads failed).
    checkpoint = torch.load(model_path, map_location=device)

    agent.actor.load_state_dict(checkpoint['actor_state_dict'])
    agent.critic.load_state_dict(checkpoint['critic_state_dict'])

    if load_optimizers:
        if 'actor_optimizer_state_dict' in checkpoint:
            agent.actor_optimizer.load_state_dict(checkpoint['actor_optimizer_state_dict'])
        if 'critic_optimizer_state_dict' in checkpoint:
            agent.critic_optimizer.load_state_dict(checkpoint['critic_optimizer_state_dict'])

    print(f"Loaded model from battle {checkpoint['battle']}")
    print(f"Saved metrics - Reward: {checkpoint['avg_reward']:.2f}, Win Rate: {checkpoint['win_rate']:.3f}")

    return agent

def evaluate_loaded_model(agent, n_battles=100, use_model=True, win_threshold=250):
    """Evaluate a loaded PPO model against SmartDamagePlayer and print the results.

    Args:
        agent: PPO agent (e.g. from ``load_ppo_model``) exposing ``actor``,
            ``critic`` and ``choose_action``.
        n_battles: number of battles to play; must be positive.
        use_model: if True, actions come from ``agent.choose_action``; otherwise
            they are sampled from the env's action space (a random baseline).
        win_threshold: a battle counts as a win when its summed reward exceeds
            this value. Defaults to 250, matching training and evaluate_ppo_agent.
            The former threshold of 30 came from an older reward function
            (victory_value=30.0). It counted losing battles with modest positive
            reward as wins.

    Returns:
        tuple: (avg_reward, win_rate_percent).

    Raises:
        ValueError: if ``n_battles`` is not positive.
    """
    if n_battles <= 0:
        raise ValueError(f"n_battles must be positive, got {n_battles}")

    # Set networks to evaluation mode
    agent.actor.eval()
    agent.critic.eval()

    opponent = SmartDamagePlayer(battle_format="gen8randombattle")
    env = SimpleRLPlayer(battle_format="gen8randombattle", opponent=opponent, start_challenging=True)

    total_rewards = 0
    wins = 0

    try:
        for battle in range(n_battles):
            state_tuple = env.reset()
            # reset() may return (observation, info); keep only the observation.
            state = state_tuple[0] if isinstance(state_tuple, tuple) else state_tuple
            state = np.array(state, dtype=np.float32)
            done = False
            battle_reward = 0

            while not done:
                if use_model:
                    with torch.no_grad():
                        action, _, _ = agent.choose_action(state)
                else:
                    action = env.action_space.sample()

                next_state, reward, done, truncated, info = env.step(action)
                battle_reward += reward
                # Keep observations float32, exactly as during training.
                state = np.array(next_state, dtype=np.float32)
                done = done or truncated

            total_rewards += battle_reward
            if battle_reward > win_threshold:
                wins += 1

            print(f"Battle {battle + 1}/{n_battles}, Reward: {battle_reward:.2f}")

        win_rate = (wins / n_battles) * 100
        avg_reward = total_rewards / n_battles
        print(f"\nEvaluation Results:")
        print(f"Average Reward: {avg_reward:.2f}")
        print(f"Win Rate: {win_rate:.2f}%")
        print(f"Wins: {wins}/{n_battles}")

    finally:
        env.close()
        # Set networks back to training mode
        agent.actor.train()
        agent.critic.train()

    return avg_reward, win_rate

# Example: restore a checkpoint written by train_ppo and benchmark it.
# Point model_path at whichever saved checkpoint you want to inspect.
model_path = 'saved_models_ppo_final_bestbasepower/latest_model_ppo.pth'
loaded_agent = load_ppo_model(model_path=model_path)

# Play 50 battles choosing actions with the loaded policy.
print("\nEvaluating with model's actions:")
evaluate_loaded_model(agent=loaded_agent, n_battles=50, use_model=True)

# For a random-action baseline, uncomment the two lines below:
# print("\nEvaluating with random actions:")
# evaluate_loaded_model(agent=loaded_agent, n_battles=50, use_model=False)

Loaded model from battle 1249
Saved metrics - Reward: 16.83, Win Rate: 0.280

Evaluating with model's actions:
Battle 1/50, Reward: -271.24
Battle 2/50, Reward: -507.44
Battle 3/50, Reward: -497.54
Battle 4/50, Reward: 1552.99
Battle 5/50, Reward: -534.28
Battle 6/50, Reward: 653.40
Battle 7/50, Reward: -562.00
Battle 8/50, Reward: 831.45
Battle 9/50, Reward: -669.03
Battle 10/50, Reward: -264.88
Battle 11/50, Reward: -544.59
Battle 12/50, Reward: -574.05
Battle 13/50, Reward: -586.81
Battle 14/50, Reward: -474.62
Battle 15/50, Reward: -423.15
Battle 16/50, Reward: -428.91
Battle 17/50, Reward: -511.64
Battle 18/50, Reward: -367.84
Battle 19/50, Reward: 810.33
Battle 20/50, Reward: 131.40
Battle 21/50, Reward: -441.21
Battle 22/50, Reward: -558.94
Battle 23/50, Reward: -137.34
Battle 24/50, Reward: -606.19
Battle 25/50, Reward: -583.36
Battle 26/50, Reward: -818.71
Battle 27/50, Reward: -530.76
Battle 28/50, Reward: -501.64
Battle 29/50, Reward: -429.09
Battle 30/50, Reward: -577.89
Ba

SyntaxError: invalid syntax (2432080017.py, line 1)